
Adds reduce_scatter into torchft #102

Merged (8 commits, Feb 10, 2025)

Conversation

@allenwang28 (Contributor) commented on Feb 6, 2025

What does this PR do?

Partially addresses #97 by adding reduce_scatter into torchft.
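For context, a backend-agnostic sketch of reduce_scatter semantics (plain Python, not torchft's actual API): each rank contributes world_size input chunks, the chunks are reduced (here, summed) element-wise across ranks, and rank i receives reduced chunk i.

```python
# Hypothetical helper simulating reduce_scatter semantics without a backend.
def simulate_reduce_scatter(per_rank_inputs):
    """per_rank_inputs[r][i] is rank r's i-th input chunk (a list of numbers).
    Returns outputs[i], the reduced chunk that rank i would receive."""
    world_size = len(per_rank_inputs)
    chunk_len = len(per_rank_inputs[0][0])
    return [
        # Sum chunk i element-wise across all ranks; rank i gets this result.
        [sum(per_rank_inputs[r][i][j] for r in range(world_size))
         for j in range(chunk_len)]
        for i in range(world_size)
    ]

# Two ranks, each contributing two 2-element chunks:
inputs = [
    [[1.0, 2.0], [3.0, 4.0]],      # rank 0's chunks
    [[10.0, 20.0], [30.0, 40.0]],  # rank 1's chunks
]
print(simulate_reduce_scatter(inputs))  # [[11.0, 22.0], [33.0, 44.0]]
```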

Concretely, this consists of a few pieces:

Tests

Presubmits, and:

$ pytest torchft/process_group_test.py 
============================================= test session starts =============================================
platform linux -- Python 3.10.16, pytest-8.3.4, pluggy-1.5.0
rootdir: /home/allencwang/workspace/torchft
configfile: pytest.ini
plugins: typeguard-2.13.3
collected 16 items                                                                                            

torchft/process_group_test.py ................                                                          [100%]

============================================= 16 passed in 31.44s =============================================
[rank0]:[W206 14:54:24.777939032 CudaIPCTypes.cpp:16] Producer process has been terminated before all shared CUDA tensors released. See Note [Sharing CUDA tensors]

Next steps

The logic of _should_run_collective is a bit confusing: it lets "non-defined" backends like ErrorSwallowing* through in order to mimic the behavior before this change. Testing here could become unwieldy as we add more collectives, so a future step could be to refactor the tests.

One nice change could be to parameterize the tests by collective. This would make a failing collective explicit in the test name and reduce the time it takes to run individual tests. We can likely do this in the next PR.
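A sketch of what parameterizing by collective could look like. The process-group class and its `run` method here are hypothetical stand-ins, not torchft's actual test helpers; the point is the per-collective test IDs and the skip-on-unsupported pattern.

```python
import pytest

# Hypothetical stand-in for a process group backend that supports only
# a subset of collectives (as Gloo does).
class FakeProcessGroup:
    SUPPORTED = {"allreduce", "allgather", "broadcast"}

    def run(self, collective: str) -> str:
        if collective not in self.SUPPORTED:
            raise RuntimeError(f"FakeProcessGroup does not support {collective}")
        return "ok"

# One test per collective: a failure names the collective directly.
@pytest.mark.parametrize(
    "collective",
    ["allreduce", "allgather", "broadcast", "reduce_scatter", "all_to_all"],
)
def test_collective(collective: str) -> None:
    pg = FakeProcessGroup()
    try:
        assert pg.run(collective) == "ok"
    except RuntimeError as e:
        if f"does not support {collective}" in str(e):
            pytest.skip(f"{collective} not supported by this backend")
        raise
```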

@facebook-github-bot facebook-github-bot added the CLA Signed This label is managed by the Meta Open Source bot. label Feb 6, 2025
The review thread below refers to this snippet from _should_run_collective:

            return True
        return False
    else:  # cpu
        if collective_str in ["reduce_scatter", "all_to_all"]:
Member:

oh wow -- didn't realize we don't support these on Gloo, good to know! cc @c-p-i-o

Contributor:

Yeah, we're missing many APIs on Gloo.

Reviewer:

this approach seems nice and explicit. but is it possible to instead just try: the test, and except: some specific NYI error? (i'm not sure if we raise a consistent type of NYI exception from backends?)

Contributor Author:

> this approach seems nice and explicit. but is it possible to instead just try: the test, and except: some specific NYI error? (i'm not sure if we raise a consistent type of NYI exception from backends?)

Yeah this is a good idea, I modified the block as follows:

    for coll_str, args in collectives:
        try:
            coll = getattr(pg, coll_str)
            work = coll(*args)
            works[coll_str] = work
            work.wait()
            fut = work.get_future()
            fut.wait()
            # Check that all tensor arguments have the expected shapes and dtypes
            check_tensors(args)
        except RuntimeError as e:
            if f"does not support {coll_str}" in str(e):
                # Skip collectives that are not supported by the backend.
                continue
            raise e

@allenwang28 allenwang28 marked this pull request as ready for review February 7, 2025 16:55
@d4l3k (Member) left a comment

LGTM, thanks for adding this!

@allenwang28 (Contributor Author):

Updated the test to be simpler and removed the utility functions I previously added, so a test refactor should no longer be needed. I wanted to parameterize by collective, but #103 shows that tests got much slower after doing so, so I will deprecate #103.

I have also added an explicit NotImplementedError for reduce_scatter within ProcessGroupBabyGloo, because otherwise the exception lives within the tx queue (see here). I noticed this when test_baby_gloo_apis would fail here, and that would be a nasty issue for a downstream user, so raising the error explicitly in the API is ultimately cleaner.
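A minimal sketch of the fail-fast rationale (the class and method here are hypothetical stand-ins, not torchft's actual ProcessGroupBabyGloo): when a process group forwards ops to a worker process via a queue, an unsupported op would otherwise fail asynchronously, far from the call site; raising NotImplementedError in the API surfaces it immediately.

```python
# Hypothetical stand-in for a process group that forwards ops to a worker.
class BabyProcessGroup:
    def _submit(self, op: str):
        # In the real design the op goes onto a tx queue and any exception
        # would surface asynchronously; here we just pretend it succeeds.
        return f"queued {op}"

    def allreduce(self, tensors):
        return self._submit("allreduce")

    def reduce_scatter(self, output, inputs):
        # Fail fast at the call site with a clear message, rather than
        # letting the failure emerge later from the worker queue.
        raise NotImplementedError(
            "BabyProcessGroup (Gloo) does not support reduce_scatter"
        )
```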

@allenwang28 allenwang28 requested a review from wconstab February 10, 2025 22:05
@d4l3k (Member) left a comment

LGTM

@allenwang28 allenwang28 merged commit e55542a into pytorch:main Feb 10, 2025
6 checks passed
@allenwang28 allenwang28 deleted the collectives branch February 10, 2025 23:25